{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `nn.pad`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`op.nn.pad` 函数，用于对输入数据进行填充。下面是对该函数的解读：\n",
    "\n",
    "该函数接受以下参数：\n",
    "- `data`：输入数据的表达式（{class}`tvm.relay.Expr`类型）\n",
    "- `pad_width`：每个轴要填充的宽度，以元组的形式给出，格式为 `((before_1, after_1), ..., (before_N, after_N))`\n",
    "- `pad_value`：（可选）填充的值，默认为 `0`\n",
    "- `pad_mode`：（可选）填充模式，可以是 `'constant'`、`'edge'` 或 `'reflect'`，分别表示使用常量值、边缘值或反射值进行填充\n",
    "\n",
    "函数首先检查`pad_width`和`pad_value`的类型，如果它们不是预期的类型，则进行相应的转换。然后根据 `pad_width` 的类型选择不同的填充方式，并返回计算结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.relay import op\n",
    "from tvm.relay.testing import run_opt_pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> (Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), float32], Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), float32]) {\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>pad(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x, <span style=\"color: #AA22FF; font-weight: bold\">-</span><span style=\"color: #008000\">1</span> <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">*/</span>, pad_width<span style=\"color: #AA22FF; font-weight: bold\">=</span>[[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">0</span>], [<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">1</span>], [<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], [<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>]]) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>pad(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #AA22FF; font-weight: bold\">-</span><span style=\"color: #008000\">1</span> <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">*/</span>, pad_width<span style=\"color: #AA22FF; font-weight: bold\">=</span>[[<span style=\"color: #AA22FF; font-weight: bold\">-</span><span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">0</span>], [<span style=\"color: #008000\">0</span>, <span style=\"color: #AA22FF; font-weight: bold\">-</span><span style=\"color: #008000\">1</span>], [<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], [<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>]]) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>;\n",
       "  (<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>(Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), float32], Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), float32]) <span style=\"color: #AA22FF; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "原始数据((1, 2, 1, 1)): \n",
      "[[[[0.]]\n",
      "\n",
      "  [[1.]]]]\n",
      "最终数据((2, 3, 1, 1)): \n",
      "[[[[-1.]]\n",
      "\n",
      "  [[-1.]]\n",
      "\n",
      "  [[-1.]]]\n",
      "\n",
      "\n",
      " [[[ 0.]]\n",
      "\n",
      "  [[ 1.]]\n",
      "\n",
      "  [[-1.]]]]\n"
     ]
    }
   ],
   "source": [
    "dshape = 1, 2, 1, 1\n",
    "pad_width = [(1, 0), (0, 1), (0, 0), (0, 0)]\n",
    "x = relay.var(\"x\", shape=dshape)\n",
    "y = op.nn.pad(x, pad_width, pad_value=-1, pad_mode='constant')\n",
    "pad_width[:2] = [(-1, 0), (0, -1)]\n",
    "x_ = op.nn.pad(y, pad_width, pad_value=-1, pad_mode='constant') # 移除 pad\n",
    "t = relay.Tuple([y, x_])\n",
    "func = relay.Function([x], t)\n",
    "func = run_opt_pass(func, relay.transform.InferType())\n",
    "tvm.IRModule.from_expr(func).show()\n",
    "intrp = relay.create_executor(\"graph\", device=tvm.cpu(0), target=\"llvm\")\n",
    "\n",
    "data_np = np.arange(np.prod(dshape)).reshape(dshape).astype(\"float32\")\n",
    "print(f\"原始数据({data_np.shape}): \\n{data_np}\")\n",
    "op_res, new_data = intrp.evaluate(func)(data_np)\n",
    "print(f\"最终数据({op_res.shape}): \\n{op_res}\")\n",
    "np.testing.assert_allclose(data_np, new_data.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
